import argparse
import glob
import os
import sys
import time

import dgl
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from model import compute_acc_unsupervised as compute_acc, SAGE

from negative_sampler import NegativeSampler
from pytorch_lightning import LightningDataModule, LightningModule, Trainer

from pytorch_lightning.callbacks import Callback, ModelCheckpoint

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from load_graph import inductive_split, load_ogb, load_reddit


class CrossEntropyLoss(nn.Module):
    def forward(self, block_outputs, pos_graph, neg_graph):
        with pos_graph.local_scope():
            pos_graph.ndata["h"] = block_outputs
            pos_graph.apply_edges(fn.u_dot_v("h", "h", "score"))
            pos_score = pos_graph.edata["score"]
        with neg_graph.local_scope():
            neg_graph.ndata["h"] = block_outputs
            neg_graph.apply_edges(fn.u_dot_v("h", "h", "score"))
            neg_score = neg_graph.edata["score"]

        score = th.cat([pos_score, neg_score])
        label = th.cat(
            [th.ones_like(pos_score), th.zeros_like(neg_score)]
        ).long()
        loss = F.binary_cross_entropy_with_logits(score, label.float())
        return loss


class SAGELightning(LightningModule):
    def __init__(
        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout, lr
    ):
        super().__init__()
        self.save_hyperparameters()
        self.module = SAGE(
            in_feats, n_hidden, n_classes, n_layers, activation, dropout
        )
        self.lr = lr
        self.loss_fcn = CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        input_nodes, pos_graph, neg_graph, mfgs = batch
        mfgs = [mfg.int().to(device) for mfg in mfgs]
        pos_graph = pos_graph.to(device)
        neg_graph = neg_graph.to(device)
        batch_inputs = mfgs[0].srcdata["features"]
        batch_labels = mfgs[-1].dstdata["labels"]
        batch_pred = self.module(mfgs, batch_inputs)
        loss = self.loss_fcn(batch_pred, pos_graph, neg_graph)
        self.log(
            "train_loss", loss, prog_bar=True, on_step=False, on_epoch=True
        )
        return loss

    def validation_step(self, batch, batch_idx):
        input_nodes, output_nodes, mfgs = batch
        mfgs = [mfg.int().to(device) for mfg in mfgs]
        batch_inputs = mfgs[0].srcdata["features"]
        batch_labels = mfgs[-1].dstdata["labels"]
        batch_pred = self.module(mfgs, batch_inputs)
        return batch_pred

    def configure_optimizers(self):
        optimizer = th.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer


class DataModule(LightningDataModule):
    def __init__(
        self,
        dataset_name,
        data_cpu=False,
        fan_out=[10, 25],
        device=th.device("cpu"),
        batch_size=1000,
        num_workers=4,
    ):
        super().__init__()
        if dataset_name == "reddit":
            g, n_classes = load_reddit()
            n_edges = g.num_edges()
            reverse_eids = th.cat(
                [th.arange(n_edges // 2, n_edges), th.arange(0, n_edges // 2)]
            )
        elif dataset_name == "ogbn-products":
            g, n_classes = load_ogb("ogbn-products")
            n_edges = g.num_edges()
            # The reverse edge of edge 0 in OGB products dataset is 1.
            # The reverse edge of edge 2 is 3.  So on so forth.
            reverse_eids = th.arange(n_edges) ^ 1
        else:
            raise ValueError("unknown dataset")

        train_nid = th.nonzero(g.ndata["train_mask"], as_tuple=True)[0]
        val_nid = th.nonzero(g.ndata["val_mask"], as_tuple=True)[0]
        test_nid = th.nonzero(
            ~(g.ndata["train_mask"] | g.ndata["val_mask"]), as_tuple=True
        )[0]

        sampler = dgl.dataloading.MultiLayerNeighborSampler(
            [int(_) for _ in fan_out]
        )

        dataloader_device = th.device("cpu")
        if not data_cpu:
            train_nid = train_nid.to(device)
            val_nid = val_nid.to(device)
            test_nid = test_nid.to(device)
            g = g.formats(["csc"])
            g = g.to(device)
            dataloader_device = device

        self.g = g
        self.train_nid, self.val_nid, self.test_nid = (
            train_nid,
            val_nid,
            test_nid,
        )
        self.sampler = sampler
        self.device = dataloader_device
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.in_feats = g.ndata["features"].shape[1]
        self.n_classes = n_classes
        self.reverse_eids = reverse_eids

    def train_dataloader(self):
        sampler = dgl.dataloading.as_edge_prediction_sampler(
            self.sampler,
            exclude="reverse_id",
            reverse_eids=self.reverse_eids,
            negative_sampler=NegativeSampler(
                self.g, args.num_negs, args.neg_share
            ),
        )
        return dgl.dataloading.DataLoader(
            self.g,
            np.arange(self.g.num_edges()),
            sampler,
            device=self.device,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=False,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        # Note that the validation data loader is a DataLoader
        # as we want to evaluate all the node embeddings.
        return dgl.dataloading.DataLoader(
            self.g,
            np.arange(self.g.num_nodes()),
            self.sampler,
            device=self.device,
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=self.num_workers,
        )


class UnsupervisedClassification(Callback):
    def on_validation_epoch_start(self, trainer, pl_module):
        self.val_outputs = []

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
    ):
        self.val_outputs.append(outputs)

    def on_validation_epoch_end(self, trainer, pl_module):
        node_emb = th.cat(self.val_outputs, 0)
        g = trainer.datamodule.g
        labels = g.ndata["labels"]
        f1_micro, f1_macro = compute_acc(
            node_emb,
            labels,
            trainer.datamodule.train_nid,
            trainer.datamodule.val_nid,
            trainer.datamodule.test_nid,
        )
        pl_module.log("val_f1_micro", f1_micro)


if __name__ == "__main__":
    argparser = argparse.ArgumentParser("multi-gpu training")
    argparser.add_argument("--gpu", type=int, default=0)
    argparser.add_argument("--dataset", type=str, default="reddit")
    argparser.add_argument("--num-epochs", type=int, default=20)
    argparser.add_argument("--num-hidden", type=int, default=16)
    argparser.add_argument("--num-layers", type=int, default=2)
    argparser.add_argument("--num-negs", type=int, default=1)
    argparser.add_argument(
        "--neg-share",
        default=False,
        action="store_true",
        help="sharing neg nodes for positive nodes",
    )
    argparser.add_argument("--fan-out", type=str, default="10,25")
    argparser.add_argument("--batch-size", type=int, default=10000)
    argparser.add_argument("--log-every", type=int, default=20)
    argparser.add_argument("--eval-every", type=int, default=1000)
    argparser.add_argument("--lr", type=float, default=0.003)
    argparser.add_argument("--dropout", type=float, default=0.5)
    argparser.add_argument(
        "--num-workers",
        type=int,
        default=0,
        help="Number of sampling processes. Use 0 for no extra process.",
    )
    args = argparser.parse_args()

    if args.gpu >= 0:
        device = th.device("cuda:%d" % args.gpu)
    else:
        device = th.device("cpu")

    datamodule = DataModule(
        args.dataset,
        True,
        [int(_) for _ in args.fan_out.split(",")],
        device,
        args.batch_size,
        args.num_workers,
    )
    model = SAGELightning(
        datamodule.in_feats,
        args.num_hidden,
        datamodule.n_classes,
        args.num_layers,
        F.relu,
        args.dropout,
        args.lr,
    )

    # Train
    unsupervised_callback = UnsupervisedClassification()
    checkpoint_callback = ModelCheckpoint(monitor="val_f1_micro", save_top_k=1)
    trainer = Trainer(
        gpus=[args.gpu] if args.gpu != -1 else None,
        max_epochs=args.num_epochs,
        val_check_interval=1000,
        callbacks=[checkpoint_callback, unsupervised_callback],
        num_sanity_val_steps=0,
    )
    trainer.fit(model, datamodule=datamodule)
